import ephere_moov as moov
import ephere_ornatrix as ox

from HairModel import HairModel, ModelType
from SolverUpdater import GetSolverUpdater

# Lazily-created HairSimulator singleton used by the module-level Evaluate(), and the optional solver
# handed to it on creation (None lets HairSimulator construct its own moov.Solver).
_hairSimulator = _solver = None

# Default parameter table. Each entry is an immutable tuple: ( name, type, default value [, min, max] ), e.g.
#	MoovDefaultParameters = ( ( 'param1', 'float', 0.0 ),
#	                          ( 'param2', 'int', 0 ),
#	                          ( 'param3', 'string', 'initial' ) )
# Numeric attributes may list min and max values after the default value.
# Special entry types used below:
#	'Group'     - starts a parameter group; the boolean is True for collapsed, False for expanded by default.
#	'Separator' - carries only a name and a type.
#	'RampCurve' - initial control points as (float position 0=root..1=tip, float value 0..1, int interpolation type) triples.
# This table holds initial defaults only; the actual current parameter values cannot be read from it.
MoovDefaultParameters = (
	('StrandGroup', 'string', ''),

	# Solver parameters
	('SolverParameters', 'Group', True),
	('SubstepCount', 'int', 10, 1, 20),
	('IterationCount', 'int', 5, 1, 20),
	('VelocityLimit', 'float', 10000.0, 0.1, 1000000000.0),
	('CollisionTolerance', 'float', 0.1, 0.01, 2),

	# Model parameters
	('HairModel', 'Group', False),
	('LatticeCount', 'int', 1, 1, 3),
	('LatticeSize', 'float', 0.1, 0.01, 0.5),
	('LatticeStiffness', 'float', 1, 0.0, 1),
	('Sep1', 'Separator'),
	('ModelType', 'enum', 2, ModelType.DistanceOnly, ModelType.DistanceBending, ModelType.Cosserat, ModelType.CosseratDistance),
	('UseCompliantConstraints', 'bool', True),
	('LimitStretch', 'bool', False),
	('SlowMovingRoots', 'bool', False),
	('Sep2', 'Separator'),
	('LongRangeLayers', 'int', 0, 0, 6),
	('LongRangeStiffness', 'float', 1.0, 0.0, 1),

	# Hair properties
	('HairProperties', 'Group', False),
	('StretchingStiffness', 'float', 0.5, 0.0, 1),
	('StretchingStiffnessChannel', 'ChannelSelector'),
	('StretchingStiffnessCurve', 'RampCurve', 0, 1, 3,  0.5, 1, 3,  1, 1, 3),
	('BendingStiffness', 'float', 0.5, 0.0, 1),
	('BendingStiffnessChannel', 'ChannelSelector'),
	('BendingStiffnessCurve', 'RampCurve', 0, 1, 3,  0.5, 1, 3,  1, 1, 3),
	('RootStiffness', 'float', 1, 0.0, 1),
	('RootVertexCount', 'int', 0),
	# ('MassPerVertex', 'float', 1.0, 0.1, 10),  -- currently disabled
	('MassChannel', 'ChannelSelector'),
	('MassCurve', 'RampCurve', 0, 1, 3,  0.5, 1, 3,  1, 1, 3),

	# Forces and fields
	('ForcesAndFields', 'Group', False),
	('Gravity', 'float', -981, -1000, 0),
	('GravityScale', 'float', 1.0, 0, 100),
	('ForceFieldScale', 'float', 1.0, 0, 1),
	('ForceFieldChannel', 'ChannelSelector'),
	('ForceFieldCurve', 'RampCurve', 0, 1, 3,  0.5, 1, 3,  1, 1, 3),
	('Drag', 'float', 0.01, 0, 1),
	('AttractToInitialShape', 'bool', False),
	('AttractToInitialShape_Stiffness', 'float', 0.5, 0, 10),
	('AttractToInitialShapeChannel', 'ChannelSelector'),
	('AttractToInitialShape_Ramp', 'RampCurve', 0, 1, 3,  0.5, 1, 3,  1, 1, 3),

	# Collisions
	('Collisions', 'Group', False),
	('CollideWithBaseMesh', 'bool', False),
	('CollideWithMeshes', 'bool', True),
	('CollideWithHair', 'bool', False),
	('CollideWithHairRadius', 'float', 0.1),
	('ParticleRadius', 'float', 0),
	('ParticleRadiusChannel', 'ChannelSelector'),
	('ParticleRadiusCurve', 'RampCurve', 0, 1, 3,  0.5, 1, 3,  1, 1, 3),
	('FrictionCoefficient', 'float', 0.5, 0, 1),
	('RestitutionCoefficient', 'float', 0, 0, 1),

	# Group holder
	('GroupHolder', 'Group', True),
	('UseGroupHolder', 'bool', False),
	('GroupHolderPosMin', 'int', 4),
	('GroupHolderPosMax', 'int', 6),
	('GroupStiffness', 'float', 0.2, 0, 2),
	('GroupRandomSeed', 'int', 1000),
	('GroupMaxConstraints', 'int', 100, 10, 1000),

	# Attachments
	('Attachments', 'Group', True),
	('AttachmentMeshes', 'DataGenerator'),
	('AttachmentPerObject', 'bool', False),
	('AttachmentStiffness', 'float', 0.5, 0, 1),
	('AttachmentDensity', 'float', 0.1, 0, 1),
	('AttachmentReleaseTime', 'time', 10),
)


# Lists of parameters requiring different actions in HairSimulator.UpdateParameters.
# The prefix-based lists skip UI 'Group' header entries: those only describe the parameter layout and are
# not simulation settings (e.g. the 'GroupHolder' header would otherwise match the 'Group' prefix).

# Solver settings pushed directly to the running solver.
solverParams = ['IterationCount', 'VelocityLimit', 'CollisionTolerance']
# Stiffness settings; changing any of these currently requests a full reset.
stiffnessParams = ['LatticeStiffness', 'StretchingStiffness', 'BendingStiffness', 'StretchingStiffnessCurve', 'BendingStiffnessCurve', 'StretchingStiffnessChannel', 'BendingStiffnessChannel', 'LongRangeStiffness']
# Collision switches and the hair capsule radius, updated in place.
collisionParams = [param[0] for param in MoovDefaultParameters if param[0].startswith( 'Collide' ) and param[1] != 'Group']
# Settings that are only consumed when the hair model is built (HairSimulator.Reset), so changing them
# requires a rebuild. AttachmentMeshes is intentionally not listed: it is an object whose equality
# semantics are not visible here (NOTE(review): confirm whether it should trigger a reset too).
resetParams = [
	'StrandGroup', 'LatticeCount', 'LatticeSize', 'ModelType', 'UseCompliantConstraints', 'LimitStretch', 'RootStiffness', 'LongRangeLayers',
	'RootVertexCount', 'MassChannel', 'MassCurve',
	'ParticleRadius', 'ParticleRadiusChannel', 'ParticleRadiusCurve', 'FrictionCoefficient', 'RestitutionCoefficient',
	'AttachmentPerObject', 'AttachmentStiffness', 'AttachmentDensity',
]
# Group holder settings; a change re-applies them and resets the group holder.
groupHolderParams = [param[0] for param in MoovDefaultParameters if param[0].startswith( 'Group' ) and param[1] != 'Group']
groupHolderParams.append( 'UseGroupHolder' )
# AttractToInitialShape implicitly affects the choice of update kernel.
updateKernelParams = ['Gravity', 'GravityScale', 'Drag', 'AttractToInitialShape']


class HairSimulator( object ):
	'''Runs a moov.Solver simulation of Ornatrix hair, advancing it by one host frame per Evaluate() call.

	The solver model is built from the input hair and a parameter dictionary (expected to contain the
	parameter names from MoovDefaultParameters). Each frame is split into SubstepCount solver steps, and
	the result is written back into the hair object.
	'''

	# Collision group passed to SetCapsuleParameters; shared by Reset() and UpdateParameters().
	_CAPSULE_COLLISION_GROUP = 1000
	# Group holder generator name; shared by Reset() and UpdateParameters().
	_GROUP_HOLDER_GENERATOR = 'Random'

	def __init__( self, solver = None ):
		'''
		:param solver: moov.Solver to simulate with; a new moov.Solver is created when None.
		'''
		super( HairSimulator, self ).__init__()
		self.isFirstEvaluation = True
		self.oldTime = None
		self.firstTime = None
		self.frameTimeStep = None
		self.solver = solver if solver is not None else moov.Solver()
		self.solverUpdater = GetSolverUpdater( self.solver, tryOpenCLFirst = True )
		self.hairModel = HairModel( self.solver )
		self.meshSolverIdToIndex = {}
		self.inputParamsDict = None
		self.oldInputParamsDict = None
		# Root tracking state, filled in by Reset(), PreStep() and ComputeSubstepRootPosition().
		# Declared here so every attribute exists from construction on.
		self.oldRootPositions = None
		self.oldRootOrientations = None
		self.rootPositions = None
		self.rootOrientations = None
		self.interpolatedRootPositions = None
		self.interpolatedRootOrientations = None
		self.substepRootPositions = None
		self.substepRootOrientations = None

	def Initialize( self, inputHair, inputParamsDict ):
		'''Stores the parameter dictionary and builds the whole simulation from scratch.'''
		self.hairModel.ClearSolverObjects()
		self.inputParamsDict = inputParamsDict
		# Keep an independent snapshot so later parameter comparisons are against these values.
		self.oldInputParamsDict = inputParamsDict.copy()
		self.Reset( inputHair )

	@staticmethod
	def _GetScaledGravity( paramsDict ):
		'''Returns the gravity value handed to the update kernel: Gravity divided by GravityScale.'''
		# NOTE(review): GravityScale's declared range starts at 0, and a zero scale makes this division
		# raise ZeroDivisionError; confirm the intended behaviour for that case.
		return paramsDict['Gravity'] / paramsDict['GravityScale']

	def Reset( self, inputHair ):
		'''Rebuilds every solver object from inputHair using the current self.inputParamsDict.

		Clears the solver and re-applies solver, model, hair, collision and group-holder settings. It then
		recreates the hair model and optional attachments, rebuilds the mesh-particle lookup and
		re-captures the root pose used for substep interpolation.
		'''
		params = self.inputParamsDict
		self.hairModel.ClearSolverObjects()
		self.hairModel.SetStrandGroupSet( params['StrandGroup'] )

		self.hairModel.SetSolverParameters(
			positionIterCount = params['IterationCount'],
			maxSpeed = params['VelocityLimit'],
			collisionTolerance = params['CollisionTolerance'],
			pbdParticleFrictionCoefficient = params['FrictionCoefficient'],
			pbdParticleRestitutionCoefficient = params['RestitutionCoefficient'] )

		# The root holder is only enabled for RootStiffness below 0.98; the stiffness is passed as its position.
		useRootHolder = ( params['RootStiffness'] < 0.98 )

		self.hairModel.params.SetModelParameters(
			latticeCount = params['LatticeCount'],
			latticeSize = params['LatticeSize'],
			latticeStiffness = params['LatticeStiffness'],
			modelType = params['ModelType'],
			useCompliantConstraints = params['UseCompliantConstraints'],
			limitStretch = params['LimitStretch'],
			longRangeLayerCount = params['LongRangeLayers'],
			longRangeStiffness = params['LongRangeStiffness'] )

		self.hairModel.params.SetHairParameters(
			stretchingStiffness = params['StretchingStiffness'],
			bendingStiffness = params['BendingStiffness'],
			useRootHolder = useRootHolder,
			rootHolderPosition = params['RootStiffness'],
			rootVertexCount = params['RootVertexCount'],
			massCurve = params['MassCurve'],
			stretchingCurve = params['StretchingStiffnessCurve'],
			bendingCurve = params['BendingStiffnessCurve'],
			stretchingChannel = params['StretchingStiffnessChannel'],
			bendingChannel = params['BendingStiffnessChannel'],
			massChannel = params['MassChannel'] )

		self.hairModel.params.SetCapsuleParameters(
			capsuleCollisionGroup = self._CAPSULE_COLLISION_GROUP,
			capsuleRadius = params['CollideWithHairRadius'] )

		self.hairModel.params.SetCollisionParameters(
			particleRadius = params['ParticleRadius'],
			particleRadiusChannel = params['ParticleRadiusChannel'],
			particleRadiusCurve = params['ParticleRadiusCurve'],
			meshFrictionCoefficient = params['FrictionCoefficient'],
			meshRestitutionCoefficient = params['RestitutionCoefficient'] )

		self.hairModel.params.SetGroupHolderParameters(
			useGroupHolder = params['UseGroupHolder'],
			groupHolderGenerator = self._GROUP_HOLDER_GENERATOR,
			groupHolderPosMin = params['GroupHolderPosMin'],
			groupHolderPosMax = params['GroupHolderPosMax'],
			groupHolderMaxGroupCount = params['GroupMaxConstraints'],
			groupHolderRandomSeed = params['GroupRandomSeed'],
			groupHolderStiffness = params['GroupStiffness'] )

		self.hairModel.SetHair( inputHair )
		self.hairModel.InitializeStrands()
		self.hairModel.CreateHairModel()

		attachmentMeshes = params['AttachmentMeshes']
		if attachmentMeshes.GetObjectCount() > 0:
			self.hairModel.CreateAttachment( attachmentMeshes, params['AttachmentStiffness'],
				params['AttachmentDensity'], params['AttachmentPerObject'] )

		self.solverUpdater.Reset( self.hairModel.dynamicParticles, self.hairModel.rootParticles )

		# Solver particle id -> index into the mesh position/rotation lists built in StepUpdate (see updateMesh).
		self.meshSolverIdToIndex = { particleId: index for index, particleId in enumerate( self.hairModel.meshParticleIds ) }

		self.hairModel.EnableCollisions( params['CollideWithBaseMesh'], params['CollideWithMeshes'], params['CollideWithHair'] )

		self.solverUpdater.SetUpdateKernelParams( self._GetScaledGravity( params ), params['Drag'] )

		# Previous-frame root pose; PreStep interpolates from it to the current pose across substeps.
		self.oldRootPositions = self.hairModel.GetHairRootPositions()
		self.oldRootOrientations = self.hairModel.GetHairRootOrientations( self.oldRootPositions )
		self.interpolatedRootPositions = None
		self.interpolatedRootOrientations = None


	def updateMesh( self, pd, positions, rotationsImag, rotationsReal ):
		'''Copies the position and rotation of one collision-mesh particle from the per-substep mesh lists.'''
		index = self.meshSolverIdToIndex[pd.id]
		pd.x = positions[index]
		pd.rotation = moov.Quaternion( rotationsReal[index], rotationsImag[index] )


	def PreStep_UpdateForces( self, time, step, hair ):
		'''Pushes force-field forces and/or attract-to-initial-shape data to the solver updater.

		Does nothing when the hair has no external forces and AttractToInitialShape is off.

		:param time: host time at the end of the frame.
		:param step: whole-frame time step.
		'''
		params = self.inputParamsDict
		hasExternalForces = hair.HasExternalForces()
		hasInitialForces = params['AttractToInitialShape']

		if not ( hasExternalForces or hasInitialForces ):
			return

		externalForces = None
		if hasExternalForces:
			externalForces = self.hairModel.GetExternalForces( self.hairModel.dynamicParticles, step, forceMultiplier = params['ForceFieldScale'], time = time )
			# Only real forces are scaled by the root-to-tip ramp; with no force fields there is nothing to scale.
			externalForces = self.hairModel.GetValuesTimesRootToTipMultipliers( externalForces, params['ForceFieldCurve'], params['ForceFieldChannel'] )

		initialPositions, rampMultipliers = None, None
		if hasInitialForces:
			initialPositions, rampMultipliers = self.hairModel.GetHairInitialPositionsAndRampMultipliers( params['AttractToInitialShape_Ramp'], params['AttractToInitialShapeChannel'] )

		self.solverUpdater.SetForces( externalForces, params['AttractToInitialShape_Stiffness'], initialPositions, rampMultipliers )


	def PreStep( self, time, step, hair ):
		'''Called once in the beginning of each frame.

		:param time: host time at the end of the frame.
		:param step: whole-frame time step (time - oldTime, as passed by Evaluate).
		'''
		self.PreStep_UpdateForces( time, step, hair )
		params = self.inputParamsDict
		slowMovingRoots = params['SlowMovingRoots']
		substepCount = params['SubstepCount']

		# Current root pose; orientations are computed relative to the previous frame's orientations.
		self.rootPositions = self.hairModel.GetHairRootPositions()
		self.rootOrientations = self.hairModel.GetHairRootOrientations( self.rootPositions, self.oldRootOrientations )

		# SlowMovingRoots: move the roots to the current pose once per frame.
		if len( self.rootPositions ) > 0 and slowMovingRoots:
			self.solverUpdater.UpdateRoots( self.hairModel.rootParticles, self.rootPositions, self.rootOrientations )

		# Otherwise precompute the intermediate root poses that StepUpdate applies substep by substep.
		if not slowMovingRoots and substepCount > 1:
			self.interpolatedRootPositions = moov.InterpolatePositions( self.oldRootPositions, self.rootPositions, substepCount - 1 )
			self.interpolatedRootOrientations = None if self.rootOrientations is None else moov.InterpolateOrientations( self.oldRootOrientations, self.rootOrientations, substepCount - 1 )

		# Release attachments once the end of this frame passes AttachmentReleaseTime.
		if self.oldTime + step > params['AttachmentReleaseTime']:
			self.hairModel.ReleaseAttachment()


	def StepUpdate( self, step, outputHair ):
		'''Called once for each substep, multiple times per frame depending on the substep count.

		:param step: substep duration (frame time step divided by SubstepCount).
		'''
		# Gather the collision meshes' transforms for this substep.
		meshPositions = []
		meshRotationsReal = []
		meshRotationsImag = []
		for mesh in outputHair.GetPolygonMeshes():
			meshPositions.append( mesh.position )
			meshRotationsImag.append( mesh.rotationQImag )
			meshRotationsReal.append( mesh.rotationQReal )

		# Per-substep root motion (SlowMovingRoots moves the roots once per frame in PreStep instead).
		if len( self.substepRootPositions ) > 0 and not self.inputParamsDict['SlowMovingRoots']:
			self.solverUpdater.UpdateRoots( self.hairModel.rootParticles, self.substepRootPositions, self.substepRootOrientations )
		self.solver.UpdateParticles( self.hairModel.meshParticles, lambda pd: self.updateMesh( pd, meshPositions, meshRotationsImag, meshRotationsReal ), moov.ParticleInformation( moov.ParticleInformation.Position | moov.ParticleInformation.Rotation ) )

		# Update dynamic particles
		if len( self.hairModel.dynamicParticles ) > 0:
			self.solverUpdater.UpdateParticles( step, self.hairModel.dynamicParticles )

		self.solver.Step( step )


	def PostStep( self, outputHair ):
		'''Called once at the end of each frame: writes the result to outputHair and remembers the root pose.'''
		self.hairModel.UpdateHair( outputHair )
		self.oldRootPositions = self.rootPositions
		self.oldRootOrientations = self.rootOrientations


	def HasChangedParameters( self, newParametersDict, paramNameList ):
		'''Returns True if any parameter named in paramNameList differs from the previously stored value.'''
		return any( newParametersDict[paramName] != self.oldInputParamsDict[paramName] for paramName in paramNameList )


	def UpdateParameters( self, newParametersDict ):
		'''Applies the parameter changes that can be made to a running simulation.

		Setters are called with the same argument sets as in Reset(), so no setting falls back to a setter default.

		Returns True if parameters were updated, and False if a full simulation reset is needed.
		'''
		result = True

		solverSettingsChanged = self.HasChangedParameters( newParametersDict, solverParams ) \
			or self.HasChangedParameters( newParametersDict, ( 'FrictionCoefficient', 'RestitutionCoefficient' ) )
		if solverSettingsChanged:
			self.hairModel.SetSolverParameters(
				positionIterCount = newParametersDict['IterationCount'],
				maxSpeed = newParametersDict['VelocityLimit'],
				collisionTolerance = newParametersDict['CollisionTolerance'],
				pbdParticleFrictionCoefficient = newParametersDict['FrictionCoefficient'],
				pbdParticleRestitutionCoefficient = newParametersDict['RestitutionCoefficient'] )

		if self.HasChangedParameters( newParametersDict, collisionParams ):
			self.hairModel.params.SetCapsuleParameters(
				capsuleCollisionGroup = self._CAPSULE_COLLISION_GROUP,
				capsuleRadius = newParametersDict['CollideWithHairRadius'] )
			self.hairModel.EnableCollisions( newParametersDict['CollideWithBaseMesh'], newParametersDict['CollideWithMeshes'], newParametersDict['CollideWithHair'] )

		if self.HasChangedParameters( newParametersDict, groupHolderParams ):
			self.hairModel.params.SetGroupHolderParameters(
				useGroupHolder = newParametersDict['UseGroupHolder'],
				groupHolderGenerator = self._GROUP_HOLDER_GENERATOR,
				groupHolderPosMin = newParametersDict['GroupHolderPosMin'],
				groupHolderPosMax = newParametersDict['GroupHolderPosMax'],
				groupHolderMaxGroupCount = newParametersDict['GroupMaxConstraints'],
				groupHolderRandomSeed = newParametersDict['GroupRandomSeed'],
				groupHolderStiffness = newParametersDict['GroupStiffness'] )
			self.hairModel.ResetGroupHolder()

		if self.HasChangedParameters( newParametersDict, updateKernelParams ):
			self.solverUpdater.SetUpdateKernelParams( self._GetScaledGravity( newParametersDict ), newParametersDict['Drag'] )

		# TODO: update stiffnesses in place instead of requesting a reset.
		if self.HasChangedParameters( newParametersDict, stiffnessParams ) or self.HasChangedParameters( newParametersDict, resetParams ):
			result = False

		self.oldInputParamsDict = newParametersDict.copy()
		self.inputParamsDict = newParametersDict

		return result


	def UpdateSolverFromCapture( self, hair, capture ):
		'''Applies an initial-state capture to the solver when one is given and non-empty.

		If the capture cannot be applied, an error is printed and the model is rebuilt with Reset().
		Returns the capture-application result, or True when there is nothing to apply.
		'''
		result = True
		if not ( capture is None or capture.IsEmpty() ):
			result = self.solver.UpdateObjectsFromCapture( capture, moov.ParticleInformation.All, moov.ConstraintInformation.All, resetKinematicData = True, exitOnError = True )
			if not result:
				print( "Error: Initial state capture cannot be used (changed hair model?); create a new capture" )
				self.Reset( hair )
			self.hairModel.UpdateHair( hair )
		return result

	@staticmethod
	def InterpolateList( list0, list1, param ):
		'''Linearly interpolates element-wise between list0 (param = 0) and list1 (param = 1).'''
		return [ x0 * ( 1.0 - param ) + x1 * param for x0, x1 in zip( list0, list1 ) ]

	def ComputeSubstepRootPosition( self, substepIndex, substepCount ):
		'''Selects the root positions/orientations that StepUpdate applies for substep substepIndex.

		The last substep, SlowMovingRoots mode and a missing previous-frame pose all use this frame's roots
		directly. Otherwise the substep's block of the arrays interpolated in PreStep is used.
		NOTE(review): the slicing assumes the interpolated arrays store consecutive blocks of
		len(rootPositions) entries, with block 0 being the previous frame; confirm against moov.InterpolatePositions.
		'''
		if substepIndex == substepCount - 1 or self.inputParamsDict['SlowMovingRoots'] or self.oldRootPositions is None:
			self.substepRootPositions = self.rootPositions
			self.substepRootOrientations = self.rootOrientations
			return
		startIndex = ( substepIndex + 1 ) * len( self.rootPositions )
		endIndex = ( substepIndex + 2 ) * len( self.rootPositions )
		self.substepRootPositions = self.interpolatedRootPositions[startIndex:endIndex]
		self.substepRootOrientations = None if self.interpolatedRootOrientations is None else self.interpolatedRootOrientations[startIndex:endIndex]


	def Evaluate( self, hair, time, resetSolver = False, resetHairModel = False, newInputParamsDict = None, initialStateCapture = None ):
		'''Advances the simulation to `time` and writes the result into `hair`.

		:param hair: Ornatrix hair object that is read from and written to.
		:param time: host time to simulate to.
		:param resetSolver: rebuild everything from scratch (also done on the very first call).
		:param resetHairModel: rebuild the hair model and re-apply the initial-state capture.
		:param newInputParamsDict: current parameter values (names from MoovDefaultParameters).
		:param initialStateCapture: optional solver capture applied after a (re)initialization.
		:returns: on (re)initialization frames, the result of UpdateSolverFromCapture (falsy if the capture
			could not be applied); True otherwise.
		'''
		if self.isFirstEvaluation or resetSolver:
			self.Initialize( hair, newInputParamsDict )
			result = self.UpdateSolverFromCapture( hair, initialStateCapture )
			self.oldTime = time
			self.firstTime = time
			self.isFirstEvaluation = False
			return result

		parameterUpdateRequiresReset = not self.UpdateParameters( newInputParamsDict )

		# Evaluating at (or before) the first simulated time restarts from the initial state. The reset request
		# from UpdateParameters is only acted on here; a reset-requiring change made on a later frame does
		# not rebuild the model by itself.
		isFirstFrame = time <= self.firstTime

		if isFirstFrame or resetHairModel:
			if parameterUpdateRequiresReset or resetHairModel:
				self.Reset( hair )
			result = self.UpdateSolverFromCapture( hair, initialStateCapture )
			self.oldTime = time
			self.firstTime = time
			self.frameTimeStep = None
			return result

		self.hairModel.SetHair( hair )
		self.hairModel.ValidateStrandIndexMaps()

		deltaT = time - self.oldTime

		# Same time as the previous evaluation: only refresh the output hair.
		if deltaT == 0:
			self.hairModel.UpdateHair( hair )
			return True

		# The first simulated frame fixes the reference time step; later deviations only produce a warning.
		if self.frameTimeStep is None:
			self.frameTimeStep = deltaT
		if abs( self.frameTimeStep - deltaT ) > 1e-3:
			print( "WARNING: MoovPhysics time step has changed since the start of the simulation. This could lead to inconsistent results." )

		substepCount = self.inputParamsDict['SubstepCount']
		step = deltaT / float( substepCount )

		hair.SetSubstepCount( substepCount )

		self.PreStep( time, deltaT, hair )

		for substepIndex in range( substepCount ):
			hair.SelectSubstep( substepIndex )
			self.ComputeSubstepRootPosition( substepIndex, substepCount )
			self.StepUpdate( step, hair )

		self.PostStep( hair )

		self.oldTime = time

		return True


# New MoovPhysics interface

def Evaluate( hair, time, resetSolver = False, resetHairModel = False, newInputParamsDict = None, initialStateCapture = None ):
	'''Module-level MoovPhysics entry point.

	Creates the shared HairSimulator on first use (handing it _solver) and forwards every argument to its
	Evaluate(); returns whatever that call returns.
	'''
	global _hairSimulator

	simulator = _hairSimulator
	if simulator is None:
		simulator = _hairSimulator = HairSimulator( _solver )

	return simulator.Evaluate( hair, time, resetSolver, resetHairModel, newInputParamsDict, initialStateCapture )
